/*
 * Wazuh Vulnerability scanner - Scan Orchestrator
 * Copyright (C) 2015, Wazuh Inc.
 * May 1, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#ifndef _OS_SCANNER_HPP
#define _OS_SCANNER_HPP

#include "chainOfResponsability.hpp"
#include "databaseFeedManager.hpp"
#include "scanContext.hpp"
#include "scannerHelper.hpp"
#include "versionMatcher/versionMatcher.hpp"
#include "wdbDataException.hpp"

/**
 * @brief Default CVE Numbering Authority (CNA) used by the OS scanner when no CNA array is configured.
 *
 * Declared `inline` so that every translation unit including this header refers to the same entity: the
 * TOsScanner template binds this constant by reference (std::make_pair), and such an odr-use inside a header
 * template is only well-formed across translation units for a variable that is not internal-linkage.
 */
inline constexpr auto OS_SCANNER_CNA {"nvd"};

/**
 * @brief OsScanner class.
 * Chain-of-responsibility step in charge of scanning the agent's operating system for vulnerabilities.
 *
 * The OS product is taken from the OS CPE stored in the scan context and looked up in the vulnerability feed, once
 * per CNA listed under ADP_DEFAULT_ARRAY_KEY in the global vendor maps or, when that key is absent, with the default
 * CNA (OS_SCANNER_CNA). Every match is recorded in the scan context (m_elements, m_matchConditions and
 * m_cnaDetectionSource). On Windows, matches without a remediation in the feed, or whose remediation hotfix is already
 * installed on the agent, are discarded afterwards. Only the "windows" and "darwin" platforms are scanned.
 *
 * @tparam TDatabaseFeedManager Feed manager used to fetch vulnerability candidates and remediation data.
 * @tparam TScanContext Scan context type.
 * @tparam TGlobalData Global data singleton providing the vendor maps.
 * @tparam TSocketDBWrapper wazuh-db socket wrapper singleton used to query the agent's installed hotfixes.
 */
template<typename TDatabaseFeedManager = DatabaseFeedManager,
         typename TScanContext = ScanContext,
         typename TGlobalData = GlobalData,
         typename TSocketDBWrapper = SocketDBWrapper>
class TOsScanner final : public AbstractHandler<std::shared_ptr<TScanContext>>
{
private:
    std::shared_ptr<TDatabaseFeedManager> m_databaseFeedManager;

    // LCOV_EXCL_START
    /**
     * @brief Discards detected Windows OS vulnerabilities that cannot be reported or are already fixed.
     *
     * A CVE is discarded when the feed has no remediation update for it, or when one of its remediation updates
     * appears in the agent's installed hotfix list. Each discarded CVE is removed from m_elements,
     * m_matchConditions and m_cnaDetectionSource so the three maps stay consistent.
     *
     * @param data Scan context holding the detected vulnerabilities.
     * @param osProduct OS product name (from the OS CPE), used for logging only.
     * @param installedHotfixes Hotfix list returned by wazuh-db for the agent.
     */
    void discardRemediatedVulnerabilities(const std::shared_ptr<TScanContext>& data,
                                          const std::string& osProduct,
                                          const nlohmann::json& installedHotfixes) const
    {
        std::vector<std::string> cvesToDiscard;

        // Collect first and erase afterwards so the map is not modified while it is being iterated.
        for (const auto& entry : data->m_elements)
        {
            const auto& cve = entry.first;
            FlatbufferDataPair<NSVulnerabilityScanner::RemediationInfo> remediations {};
            m_databaseFeedManager->getVulnerabilityRemediation(cve, remediations);

            if (remediations.data == nullptr || remediations.data->updates() == nullptr ||
                remediations.data->updates()->size() == 0)
            {
                logDebug2(WM_VULNSCAN_LOGTAG,
                          "No remediation available for OS '%s' on Agent '%s' for CVE: '%s', discarding.",
                          osProduct.c_str(),
                          data->agentId().data(),
                          cve.c_str());
                cvesToDiscard.push_back(cve);
                continue;
            }

            for (const auto& remediation : *(remediations.data->updates()))
            {
                // Discard the CVE if the update that remediates it is already installed.
                if (std::find_if(installedHotfixes.begin(),
                                 installedHotfixes.end(),
                                 [&](const auto& element) {
                                     return element.contains("hotfix") && element.at("hotfix") == remediation->str();
                                 }) != installedHotfixes.end())
                {
                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Remediation for OS '%s' on Agent '%s' has been found. CVE: '%s', "
                              "Remediation: '%s'.",
                              osProduct.c_str(),
                              data->agentId().data(),
                              cve.c_str(),
                              remediation->str().c_str());
                    cvesToDiscard.push_back(cve);
                    break;
                }
            }
        }

        for (const auto& cve : cvesToDiscard)
        {
            data->m_elements.erase(cve);
            data->m_matchConditions.erase(cve);
            data->m_cnaDetectionSource.erase(cve);
        }
    }
    // LCOV_EXCL_STOP

public:
    /**
     * @brief OsScanner constructor.
     *
     * @param databaseFeedManager Database feed manager.
     */
    explicit TOsScanner(std::shared_ptr<TDatabaseFeedManager> databaseFeedManager)
        : m_databaseFeedManager(std::move(databaseFeedManager))
    {
    }
    /**
     * @brief Scans the OS described by the scan context and passes control to the next step of the chain.
     *
     * @param data Scan context.
     * @return std::shared_ptr<ScanContext> Result of the next handler in the chain, or nullptr when the OS platform
     * is not supported or the agent's hotfixes could not be retrieved.
     * @throws WdbDataException If the wazuh-db hotfix query fails with a SocketDbWrapperException.
     */
    // LCOV_EXCL_START
    std::shared_ptr<TScanContext> handleRequest(std::shared_ptr<TScanContext> data) override
    {
        // Installed hotfixes reported by wazuh-db; only queried (and only used) for Windows agents.
        nlohmann::json responseHotfixes;

        if (data->osPlatform() == "windows")
        {
            try
            {
                TSocketDBWrapper::instance().query(
                    WazuhDBQueryBuilder::builder().agentGetHotfixesCommand(data->agentId().data()).build(),
                    responseHotfixes);
            }
            catch (const SocketDbWrapperException& e)
            {
                throw WdbDataException(e.what(), data->agentId());
            }
            catch (const std::exception& e)
            {
                logError(WM_VULNSCAN_LOGTAG,
                         "Unable to retrieve hotfixes for agent %s. Reason: %s.",
                         data->agentId().data(),
                         e.what());
                return nullptr;
            }
        }

        const auto osCPE = ScannerHelper::parseCPE(data->osCPEName().data());

        // Feed callback invoked for each vulnerability candidate of the OS product. It records a match in the scan
        // context and returns true when the CVE is (or already was) recorded, false otherwise. How the feed manager
        // uses the return value is defined by getVulnerabilitiesCandidates.
        auto vulnerabilityScan = [&](const std::string& cnaName,
                                     [[maybe_unused]] const PackageData& package,
                                     const NSVulnerabilityScanner::ScanVulnerabilityCandidate& callbackData)
        {
            try
            {
                // We don't override the vulnerability if it was detected before.
                auto cveId = callbackData.cveId()->str();
                if (data->m_elements.find(cveId) != data->m_elements.end())
                {
                    logDebug1(WM_VULNSCAN_LOGTAG,
                              "CVE '%s' already found by a higher priority CNA. Skipping.",
                              cveId.c_str());
                    return true;
                }

                const std::string osVersion {data->osVersion()};
                std::variant<VersionObjectType, VersionMatcherStrategy> objectType = VersionObjectType::DPKG;
                auto osVersionObject = VersionMatcher::createVersionObject(osVersion, objectType);

                for (const auto& version : *callbackData.versions())
                {
                    std::string versionString {version->version() ? version->version()->str() : ""};
                    std::string versionStringLessThan {version->lessThan() ? version->lessThan()->str() : ""};
                    std::string versionStringLessThanOrEqual {
                        version->lessThanOrEqual() ? version->lessThanOrEqual()->str() : ""};

                    logDebug2(WM_VULNSCAN_LOGTAG,
                              "Scanning OS - '%s' (Installed Version: %s, Security Vulnerability: %s). Identified "
                              "vulnerability: "
                              "Version: %s. Required Version Threshold: %s. Required Version Threshold (or Equal): %s.",
                              osCPE.product.c_str(),
                              osVersion.c_str(),
                              cveId.c_str(),
                              versionString.c_str(),
                              versionStringLessThan.c_str(),
                              versionStringLessThanOrEqual.c_str());

                    // No version range specified, check if the installed version is equal to the required version.
                    if (versionStringLessThan.empty() && versionStringLessThanOrEqual.empty())
                    {
                        if (VersionMatcher::compare(osVersionObject, osVersion, versionString, objectType) ==
                            VersionComparisonResult::A_EQUAL_B)
                        {
                            // Version match found, the package status is defined by the vulnerability status.
                            if (version->status() == NSVulnerabilityScanner::Status::Status_affected)
                            {
                                logDebug1(WM_VULNSCAN_LOGTAG,
                                          "Match found, the OS '%s', is vulnerable to '%s'. Current version: '%s' is "
                                          "equal to '%s'. - Agent '%s' (ID: '%s', Version: '%s').",
                                          osCPE.product.c_str(),
                                          cveId.c_str(),
                                          osVersion.c_str(),
                                          versionString.c_str(),
                                          data->agentName().data(),
                                          data->agentId().data(),
                                          data->agentVersion().data());

                                data->m_elements[cveId] = nlohmann::json::object();
                                data->m_matchConditions[cveId] = {std::move(versionString), MatchRuleCondition::Equal};
                                data->m_cnaDetectionSource[cveId] = cnaName;
                                return true;
                            }

                            return false;
                        }
                    }
                    else
                    {
                        // Version range specified.
                        // Check if the installed version satisfies the lower bound of the version range; "0" means
                        // the range has no lower bound.
                        auto lowerBoundMatch = false;
                        if (versionString.compare("0") == 0)
                        {
                            lowerBoundMatch = true;
                        }
                        else
                        {
                            const auto matchResult =
                                VersionMatcher::compare(osVersionObject, osVersion, versionString, objectType);
                            lowerBoundMatch = matchResult == VersionComparisonResult::A_GREATER_THAN_B ||
                                              matchResult == VersionComparisonResult::A_EQUAL_B;
                        }

                        if (lowerBoundMatch)
                        {
                            // Check if the installed version satisfies the upper bound of the version range.
                            // A "*" lessThan is not compared; only lessThanOrEqual (if any) is then used.
                            auto upperBoundMatch = false;
                            if (!versionStringLessThan.empty() && versionStringLessThan.compare("*") != 0)
                            {
                                const auto matchResult = VersionMatcher::compare(
                                    osVersionObject, osVersion, versionStringLessThan, objectType);
                                upperBoundMatch = matchResult == VersionComparisonResult::A_LESS_THAN_B;
                            }
                            else if (!versionStringLessThanOrEqual.empty())
                            {
                                const auto matchResult = VersionMatcher::compare(
                                    osVersionObject, osVersion, versionStringLessThanOrEqual, objectType);
                                upperBoundMatch = matchResult == VersionComparisonResult::A_LESS_THAN_B ||
                                                  matchResult == VersionComparisonResult::A_EQUAL_B;
                            }
                            else
                            {
                                upperBoundMatch = false;
                            }

                            if (upperBoundMatch)
                            {
                                // Version match found, the package status is defined by the vulnerability status.
                                if (version->status() == NSVulnerabilityScanner::Status::Status_affected)
                                {
                                    logDebug1(
                                        WM_VULNSCAN_LOGTAG,
                                        "Match found, the OS '%s', is vulnerable to '%s'. Current version: "
                                        "'%s' ("
                                        "less than '%s' or equal to '%s'). - Agent '%s' (ID: '%s', Version: '%s').",
                                        osCPE.product.c_str(),
                                        cveId.c_str(),
                                        osVersion.c_str(),
                                        versionStringLessThan.c_str(),
                                        versionStringLessThanOrEqual.c_str(),
                                        data->agentName().data(),
                                        data->agentId().data(),
                                        data->agentVersion().data());

                                    data->m_elements[cveId] = nlohmann::json::object();
                                    data->m_cnaDetectionSource[cveId] = cnaName;

                                    if (!versionStringLessThanOrEqual.empty())
                                    {
                                        data->m_matchConditions[cveId] = {std::move(versionStringLessThanOrEqual),
                                                                          MatchRuleCondition::LessThanOrEqual};
                                    }
                                    else
                                    {
                                        data->m_matchConditions[cveId] = {std::move(versionStringLessThan),
                                                                          MatchRuleCondition::LessThan};
                                    }
                                    return true;
                                }
                                else
                                {
                                    logDebug2(WM_VULNSCAN_LOGTAG,
                                              "No match due to default status for OS: %s, Version: %s while scanning "
                                              "for Vulnerability: %s, "
                                              "Installed Version: %s, Required Version Threshold: %s, Required Version "
                                              "Threshold (or Equal): %s",
                                              osCPE.product.c_str(),
                                              osVersion.c_str(),
                                              cveId.c_str(),
                                              versionString.c_str(),
                                              versionStringLessThan.c_str(),
                                              versionStringLessThanOrEqual.c_str());

                                    return false;
                                }
                            }
                        }
                    }
                }

                // No match found, the default status defines the package status.
                if (callbackData.defaultStatus() == NSVulnerabilityScanner::Status::Status_affected)
                {
                    logDebug1(WM_VULNSCAN_LOGTAG,
                              "Match found for OS: %s for vulnerability: %s due to default status.",
                              osCPE.product.c_str(),
                              cveId.c_str());

                    data->m_elements[cveId] = nlohmann::json::object();
                    data->m_cnaDetectionSource[cveId] = cnaName;
                    data->m_matchConditions[cveId] = {"", MatchRuleCondition::DefaultStatus};
                    return true;
                }

                logDebug2(WM_VULNSCAN_LOGTAG,
                          "No match due to default status for OS: %s, Version: %s while scanning for Vulnerability: %s",
                          osCPE.product.c_str(),
                          osVersion.c_str(),
                          cveId.c_str());

                return false;
            }
            catch (const std::exception& e)
            {
                // Log the warning and continue with the next vulnerability.
                logDebug1(WM_VULNSCAN_LOGTAG,
                          "Failed to scan OS: '%s', CVE Numbering Authorities (CNA): '%s', Error: '%s'",
                          osCPE.product.c_str(),
                          cnaName.c_str(),
                          e.what());

                return false;
            }
        };

        try
        {
            if (data->osPlatform() == "windows" || data->osPlatform() == "darwin")
            {
                if (osCPE.product.empty())
                {
                    logDebug1(WM_VULNSCAN_LOGTAG,
                              "No CPE product found for OS '%s' on Agent '%s'.",
                              data->osName().data(),
                              data->agentId().data());
                }
                else
                {
                    PackageData package = {.name = osCPE.product, .vendor = {}, .format = {}, .version = {}};

                    data->m_vulnerabilitySource = std::make_pair(OS_SCANNER_CNA, OS_SCANNER_CNA);

                    // Bind the vendor maps once. `const auto&` keeps the CNA array reference below valid whether
                    // vendorMaps() returns a reference or a temporary (the temporary's lifetime is extended).
                    const auto& vendorMaps = TGlobalData::instance().vendorMaps();

                    if (vendorMaps.contains(ADP_DEFAULT_ARRAY_KEY))
                    {
                        for (const nlohmann::json& cna : vendorMaps.at(ADP_DEFAULT_ARRAY_KEY))
                        {
                            logDebug2(WM_VULNSCAN_LOGTAG,
                                      "Using CNA '%s' from CNA array for OS.",
                                      cna.get<std::string>().c_str());
                            m_databaseFeedManager->getVulnerabilitiesCandidates(cna, package, vulnerabilityScan);
                        }
                    }
                    else
                    {
                        logDebug2(
                            WM_VULNSCAN_LOGTAG, "No CNA array found for OS, using default CNA '%s'.", OS_SCANNER_CNA);

                        m_databaseFeedManager->getVulnerabilitiesCandidates(OS_SCANNER_CNA, package, vulnerabilityScan);
                    }

                    if (data->osPlatform() == "windows")
                    {
                        discardRemediatedVulnerabilities(data, osCPE.product, responseHotfixes);
                    }
                }
            }
            else
            {
                logDebug1(WM_VULNSCAN_LOGTAG,
                          "OS scan for platform '%s' on Agent '%s' is not supported.",
                          data->osPlatform().data(),
                          data->agentId().data());
                return nullptr;
            }
        }
        catch (const std::exception& e)
        {
            logWarn(WM_VULNSCAN_LOGTAG,
                    "Failed to scan OS: '%s', CVE Numbering Authorities (CNA): 'nvd', Error: '%s'.",
                    osCPE.product.empty() ? data->osName().data() : osCPE.product.c_str(),
                    e.what());
        }

        logDebug1(WM_VULNSCAN_LOGTAG,
                  "Vulnerability scan for OS '%s' on Agent '%s' has completed.",
                  osCPE.product.empty() ? data->osName().data() : osCPE.product.c_str(),
                  data->agentId().data());

        return AbstractHandler<std::shared_ptr<TScanContext>>::handleRequest(std::move(data));
    }
    // LCOV_EXCL_STOP
};

using OsScanner = TOsScanner<>;

#endif // _OS_SCANNER_HPP
